package divert

import (
	"errors"
	"log"
	"unsafe"

	"gitee.com/general252/dllimports"
	"golang.org/x/sys/windows"
)

// WinDivert user-mode API reference: https://reqrypt.org/windivert-doc.html#divert_open

// WinDivert.dll and the four functions this file calls from it, listed in the
// order a handle uses them: open, receive, send, close.
var (
	_winDivert = dllimports.NewLazyDLL("WinDivert.dll")

	_WinDivertOpen  = _winDivert.NewProc("WinDivertOpen")
	_WinDivertRecv  = _winDivert.NewProc("WinDivertRecv")
	_WinDivertSend  = _winDivert.NewProc("WinDivertSend")
	_WinDivertClose = _winDivert.NewProc("WinDivertClose")
)

// WinDivert wraps one WinDivert handle together with a receive buffer that is
// reused by every read.
type WinDivert struct {
	// handle holds the value stored by WinDivertOpen; it is zero until then.
	handle uintptr
	// buffer receives packet data; the slices returned by the Read* methods
	// are sub-slices of it, so each read overwrites the previous packet.
	buffer []byte
}

// NewWinDivert returns a WinDivert whose receive buffer is pre-sized to
// WINDIVERT_MTU_MAX bytes. No handle is held yet; call WinDivertOpen before
// reading or sending.
func NewWinDivert() *WinDivert {
	d := new(WinDivert)
	d.buffer = make([]byte, WINDIVERT_MTU_MAX)
	return d
}

// WinDivertOpen opens a WinDivert handle for filter at the given layer,
// priority and flags, and stores it in c for the Read*/Send/Close methods.
//
// It wraps the C function:
//
//	HANDLE WinDivertOpen(
//	    __in const char *filter,
//	    __in WINDIVERT_LAYER layer,
//	    __in INT16 priority,
//	    __in UINT64 flags
//	);
//
// Failure is detected from the returned handle (INVALID_HANDLE_VALUE, or a
// NULL handle defensively), not from the thread's last-error value, which is
// only meaningful after a failed call. On failure c is left unchanged and the
// call's error is returned. If that error is missing or is the "success"
// errno, a generic error is returned so a failure is never reported as nil.
// A filter containing a NUL byte is rejected before the DLL is called.
//
// NOTE(review): flags is UINT64 in C; passing it as a single uintptr is only
// correct on 64-bit targets.
func (c *WinDivert) WinDivertOpen(
	filter string,
	layer WINDIVERT_LAYER,
	priority WINDIVERT_PRIORITY,
	flags WINDIVERT_FLAGS) error {
	// BytePtrFromString returns a real pointer to a NUL-terminated copy.
	// Converting it to uintptr inside the Call argument list is the pattern
	// that unsafe.Pointer rule (4) protects.
	// NOTE(review): that protection assumes dllimports' Call is marked
	// //go:uintptrescapes like windows.LazyProc.Call. The rest of this file
	// (e.g. &packetLen in read) already relies on the same assumption.
	filterPtr, err := windows.BytePtrFromString(filter)
	if err != nil {
		return err
	}

	r0, _, err := _WinDivertOpen.Call(
		uintptr(unsafe.Pointer(filterPtr)),
		uintptr(layer),
		uintptr(priority),
		uintptr(flags),
	)
	if r0 == 0 || r0 == uintptr(windows.InvalidHandle) {
		if err == nil || errors.Is(err, windows.NTE_OP_OK) {
			return errors.New("windivert: WinDivertOpen failed")
		}
		return err
	}

	c.handle = r0
	return nil
}

// WinDivertClose closes the handle stored by WinDivertOpen.
//
// It does nothing when no handle is held (zero, or INVALID_HANDLE_VALUE).
// After closing it clears the stored handle, so a repeated Close cannot pass a
// stale value to the DLL that may by then refer to an unrelated handle.
func (c *WinDivert) WinDivertClose() {
	// ^uintptr(0) is INVALID_HANDLE_VALUE.
	if c.handle == 0 || c.handle == ^uintptr(0) {
		return
	}
	// Best effort: this method has no error return, so the BOOL result of
	// the C WinDivertClose is deliberately discarded.
	_, _, _ = _WinDivertClose.Call(c.handle)
	c.handle = 0
}

// ReadNetwork 类似pcap
func (c *WinDivert) ReadNetwork() ([]byte, *WinDivertDataNetwork, error) {
	var addr WinDivertDataNetwork
	packet, err := c.read(uintptr(unsafe.Pointer(&addr)))
	if err != nil {
		return nil, nil, err
	}

	return packet, &addr, nil
}

// ReadDataFlow 流建立事件
func (c *WinDivert) ReadDataFlow() ([]byte, *WinDivertDataFlow, error) {
	var addr WinDivertDataFlow
	packet, err := c.read(uintptr(unsafe.Pointer(&addr)))
	if err != nil {
		return nil, nil, err
	}

	return packet, &addr, nil
}

// ReadDataSocket socket事件
func (c *WinDivert) ReadDataSocket() ([]byte, *WinDivertDataSocket, error) {
	var addr WinDivertDataSocket
	packet, err := c.read(uintptr(unsafe.Pointer(&addr)))
	if err != nil {
		return nil, nil, err
	}

	return packet, &addr, nil
}

// ReadDataReflect 其它divert事件
func (c *WinDivert) ReadDataReflect() ([]byte, *WinDivertDataReflect, error) {
	var addr WinDivertDataReflect
	packet, err := c.read(uintptr(unsafe.Pointer(&addr)))
	if err != nil {
		return nil, nil, err
	}

	return packet, &addr, nil
}

func (c *WinDivert) WinDivertSend(packet []byte, addr *WinDivertAddress) (n int, err error) {
	/*
		BOOL WinDivertSend(
		    __in HANDLE handle,
		    __in const VOID *pPacket,
		    __in UINT packetLen,
		    __out_opt UINT *pSendLen,
		    __in const WINDIVERT_ADDRESS *pAddr
		);
	*/
	data := bytesToUTF8Ptr(packet)
	sendLen := uint32(0)
	r0, _, err := _WinDivertSend.Call(c.handle, data, uintptr(len(packet)), uintptr(unsafe.Pointer(&sendLen)), uintptr(unsafe.Pointer(addr)))
	if r0 != 1 {
		return 0, err
	}
	return int(sendLen), nil
}

func (c *WinDivert) read(addr uintptr) ([]byte, error) {
	/*
		BOOL WinDivertRecv(
		    __in HANDLE handle,
		    __out_opt PVOID pPacket,
		    __in UINT packetLen,
		    __out_opt UINT *pRecvLen,
		    __out_opt WINDIVERT_ADDRESS *pAddr
		);
	*/
	buffer := c.buffer
	bufferSize := uint32(WINDIVERT_MTU_MAX)
	packetLen := uint32(0)

	_, _, err := _WinDivertRecv.Call(c.handle,
		uintptr(unsafe.Pointer(&buffer[0])), uintptr(bufferSize),
		uintptr(unsafe.Pointer(&packetLen)),
		addr,
	)
	if !errors.Is(err, windows.NTE_OP_OK) {
		log.Println(err)
		return nil, err
	}

	return buffer[:packetLen], nil
}

// stringToUTF8Ptr copies the bytes of s unchanged (no transcoding) into a
// fresh buffer with a trailing NUL and returns that buffer's address.
//
// WARNING: the result is a bare uintptr. Nothing keeps the buffer alive or in
// place after this function returns, and newer compilers may even allocate
// small buffers on this function's own stack. Do not pass the result to a
// foreign call. Instead, convert a live pointer (e.g. from
// windows.BytePtrFromString) with uintptr(unsafe.Pointer(p)) directly in the
// call's argument list.
func stringToUTF8Ptr(s string) uintptr {
	// Copy straight from the string into the terminated buffer. Going through
	// []byte(s) first made a second, throw-away copy.
	buf := make([]byte, len(s)+1)
	copy(buf, s)
	return uintptr(unsafe.Pointer(&buf[0]))
}

// bytesToUTF8Ptr copies s into a fresh buffer with one extra trailing zero
// byte, so the copy is NUL-terminated as C strings are, and returns the
// buffer's address.
//
// WARNING: the result is a bare uintptr. Nothing keeps the buffer alive or in
// place after this function returns, and newer compilers may allocate small
// buffers on this function's own stack. Passing the result to a foreign call
// is unsafe. Convert a live pointer with uintptr(unsafe.Pointer(p)) directly
// in the call's argument list instead.
func bytesToUTF8Ptr(s []byte) uintptr {
	terminated := make([]uint8, len(s)+1)
	copy(terminated, s)
	first := &terminated[0]
	return uintptr(unsafe.Pointer(first))
}
